{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Lf7huAiYp-An"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "cellView": "form",
        "id": "YHz2D-oIqBWa"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x44FFES-r6y0"
      },
      "source": [
        "# 文本生成联合学习"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iPFgLeZIsZ3Q"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://tensorflow.google.cn/federated/tutorials/federated_learning_for_text_generation\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\">在 TensorFlow.org 上查看 </a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/federated/tutorials/federated_learning_for_text_generation.ipynb\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\">在 Google Colab 中运行 </a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/federated/tutorials/federated_learning_for_text_generation.ipynb\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\">在 GitHub 中查看源代码</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KbNz2tuvsAFB"
      },
      "source": [
        "**注**：本 Colab 已通过验证，可与[最新发布版本](https://github.com/tensorflow/federated#compatibility)的 `tensorflow_federated` pip 软件包一起使用，但 Tensorflow Federated 项目仍处于预发布开发阶段，可能无法在 `master` 上运行。\n",
        "\n",
        "本教程以“[图像分类联合学习](federated_learning_for_image_classification.ipynb)”教程中的概念为基础，演示了联合学习的其他几种实用方法。\n",
        "\n",
        "特别是，我们加载了先前训练的 Keras 模型，并使用基于（模拟）分散数据集的联合训练对其进行优化。这一操作非常重要，原因有几点。能够使用序列化模型，便可方便地将联合学习与其他机器学习方法混合使用。此外，该操作可用的预训练模型范围也在不断扩大——例如，由于预训练模型现已广泛可用（请参见 [TF Hub](https://tensorflow.google.cn/hub) 等库），因此基本不需要从头开始训练语言模型。取而代之的有效方式是，从预训练模型开始，使用联合学习对其进行优化，以适应特定应用分散数据的特定特征。\n",
        "\n",
        "在本教程中，我们将从能够生成 ASCII 字符的 RNN 开始，通过联合学习对其进行优化。我们还将展示如何将最终权重反馈给原始 Keras 模型，从而简化使用标准工具进行评估和生成文本的工作。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9LcC1AwjoqfR"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install --quiet --upgrade tensorflow_federated"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "ZjDQysatrc2S"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "b'Hello, World!'"
            ]
          },
          "execution_count": 3,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "import collections\n",
        "import functools\n",
        "import os\n",
        "import time\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_federated as tff\n",
        "\n",
        "np.random.seed(0)\n",
        "\n",
        "# Test the TFF is working:\n",
        "tff.federated_computation(lambda: 'Hello, World!')()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lyICXwVAxvW9"
      },
      "source": [
        "## 加载预训练模型\n",
        "\n",
        "我们加载的模型根据 TensorFlow 教程“[使用 RNN 通过 Eager Execution 生成文本](https://tensorflow.google.cn/tutorials/sequences/text_generation)”进行了预训练。但是，我们没有使用《[莎士比亚全集](http://www.gutenberg.org/files/100/100-0.txt)》作为数据集，而是基于查尔斯·狄更斯的《[双城记](http://www.ibiblio.org/pub/docs/books/gutenberg/9/98/98.txt)》和《[圣诞颂歌](http://www.ibiblio.org/pub/docs/books/gutenberg/4/46/46.txt)》中的文本对模型进行预训练。\n",
        "\n",
        "除扩大了词汇量，我们并没有修改原始教程，所以这一初始模型不是最先进的模型，但它可以产生合理的预测，足以满足我们的教学目的。最终模型使用 `tf.keras.models.save_model(include_optimizer=False)` 保存。\n",
        "\n",
        "在本教程中，我们将使用 TFF 提供的联合版本数据，通过联合学习针对莎士比亚作品微调此模型。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XgF8e2Ksyq1F"
      },
      "source": [
        "### 生成词汇查找表"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "IlCgQBRVymwR"
      },
      "outputs": [],
      "source": [
        "# A fixed vocabularly of ASCII chars that occur in the works of Shakespeare and Dickens:\n",
        "vocab = list('dhlptx@DHLPTX $(,048cgkoswCGKOSW[_#\\'/37;?bfjnrvzBFJNRVZ\"&amp;*.26:\\naeimquyAEIMQUY]!%)-159\\r')\n",
        "\n",
        "# Creating a mapping from unique characters to indices\n",
        "char2idx = {u:i for i, u in enumerate(vocab)}\n",
        "idx2char = np.array(vocab)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2EH6MFRdzAwd"
      },
      "source": [
        "### 加载预训练模型并生成一些文本"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "iIK674SrtCTm"
      },
      "outputs": [],
      "source": [
        "def load_model(batch_size):\n",
        "  urls = {\n",
        "      1: 'https://storage.googleapis.com/tff-models-public/dickens_rnn.batch1.kerasmodel',\n",
        "      8: 'https://storage.googleapis.com/tff-models-public/dickens_rnn.batch8.kerasmodel'}\n",
        "  assert batch_size in urls, 'batch_size must be in ' + str(urls.keys())\n",
        "  url = urls[batch_size]\n",
        "  local_file = tf.keras.utils.get_file(os.path.basename(url), origin=url)  \n",
        "  return tf.keras.models.load_model(local_file, compile=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "WvuwZBX5Ogfd"
      },
      "outputs": [],
      "source": [
        "def generate_text(model, start_string):\n",
        "  # From https://tensorflow.google.cn/tutorials/sequences/text_generation\n",
        "  num_generate = 200\n",
        "  input_eval = [char2idx[s] for s in start_string]\n",
        "  input_eval = tf.expand_dims(input_eval, 0)\n",
        "  text_generated = []\n",
        "  temperature = 1.0\n",
        "\n",
        "  model.reset_states()\n",
        "  for i in range(num_generate):\n",
        "    predictions = model(input_eval)\n",
        "    predictions = tf.squeeze(predictions, 0)\n",
        "    predictions = predictions / temperature\n",
        "    predicted_id = tf.random.categorical(\n",
        "        predictions, num_samples=1)[-1, 0].numpy()\n",
        "    input_eval = tf.expand_dims([predicted_id], 0)\n",
        "    text_generated.append(idx2char[predicted_id])\n",
        "\n",
        "  return (start_string + ''.join(text_generated))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "MGAdStJ5wDPV"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading data from https://storage.googleapis.com/tff-models-public/dickens_rnn.batch1.kerasmodel\n",
            "16195584/16193984 [==============================] - 0s 0us/step\n",
            "16203776/16193984 [==============================] - 0s 0us/step\n",
            "What of TensorFlow Federated, you ask? Sall\n",
            "yesterday. Received the Bailey.\"\n",
            "\n",
            "\"Mr. Lorry, grimmering himself, or low varked thends the winter, and the eyes of Monsieur\n",
            "Defarge. \"Let his mind, hon in his\n",
            "life and message; four declare\n"
          ]
        }
      ],
      "source": [
        "# Text generation requires a batch_size=1 model.\n",
        "keras_model_batch1 = load_model(batch_size=1)\n",
        "print(generate_text(keras_model_batch1, 'What of TensorFlow Federated, you ask? '))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kKMUn-TlgxuP"
      },
      "source": [
        "## 加载并预处理联合莎士比亚数据\n",
        "\n",
        "`tff.simulation.datasets` 软件包提供了各种数据集，这些数据集被拆分成“客户端”，其中每个客户端对应于可能参与联合学习的特定设备上的数据集。\n",
        "\n",
        "这些数据集提供了真实的非独立同分布数据，可在模拟过程中复制基于真实分散数据进行训练的挑战。这些数据的部分预处理是使用 [Leaf 项目](https://arxiv.org/abs/1812.01097) ([GitHub](https://github.com/TalwalkarLab/leaf)) 中的工具完成的。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "di3nStTDg0qc"
      },
      "outputs": [],
      "source": [
        "train_data, test_data = tff.simulation.datasets.shakespeare.load_data()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_iiY65Vv4QNK"
      },
      "source": [
        "`shakespeare.load_data()` 提供的数据集由一系列字符串 `Tensors` 构成，一个字符串代表莎士比亚戏剧中特定角色的一句台词。客户端键由戏剧名和参演角色名构成，例如 `MUCH_ADO_ABOUT_NOTHING_OTHELLO` 即对应于 Othello（奥赛罗）角色在戏剧 *Much Ado About Nothing*（《无事生非》）中的台词。请注意，在真实的联合学习场景中，并不会通过 ID 来标识或跟踪客户端，但对于模拟而言，使用键控数据集则非常实用。\n",
        "\n",
        "例如，我们可以查看 King Lear（《李尔王》）的如下数据："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "FEKiy1ntmmnk"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "tf.Tensor(b'', shape=(), dtype=string)\n",
            "tf.Tensor(b'What?', shape=(), dtype=string)\n"
          ]
        }
      ],
      "source": [
        "# Here the play is \"The Tragedy of King Lear\" and the character is \"King\".\n",
        "raw_example_dataset = train_data.create_tf_dataset_for_client(\n",
        "    'THE_TRAGEDY_OF_KING_LEAR_KING')\n",
        "# To allow for future extensions, each entry x\n",
        "# is an OrderedDict with a single key 'snippets' which contains the text.\n",
        "for x in raw_example_dataset.take(2):\n",
        "  print(x['snippets'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kUnbI5Hp4sXg"
      },
      "source": [
        "现在，我们使用 `tf.data.Dataset` 转换来准备此数据，用于训练上面加载的字符 RNN。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "9kDkmGe-7No7"
      },
      "outputs": [],
      "source": [
        "# Input pre-processing parameters\n",
        "SEQ_LENGTH = 100\n",
        "BATCH_SIZE = 8\n",
        "BUFFER_SIZE = 100  # For dataset shuffling"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "W95Of6Bwsrfc"
      },
      "outputs": [],
      "source": [
        "# Construct a lookup table to map string chars to indexes,\n",
        "# using the vocab loaded above:\n",
        "table = tf.lookup.StaticHashTable(\n",
        "    tf.lookup.KeyValueTensorInitializer(\n",
        "        keys=vocab, values=tf.constant(list(range(len(vocab))),\n",
        "                                       dtype=tf.int64)),\n",
        "    default_value=0)\n",
        "\n",
        "\n",
        "def to_ids(x):\n",
        "  s = tf.reshape(x['snippets'], shape=[1])\n",
        "  chars = tf.strings.bytes_split(s).values\n",
        "  ids = table.lookup(chars)\n",
        "  return ids\n",
        "\n",
        "\n",
        "def split_input_target(chunk):\n",
        "  input_text = tf.map_fn(lambda x: x[:-1], chunk)\n",
        "  target_text = tf.map_fn(lambda x: x[1:], chunk)\n",
        "  return (input_text, target_text)\n",
        "\n",
        "\n",
        "def preprocess(dataset):\n",
        "  return (\n",
        "      # Map ASCII chars to int64 indexes using the vocab\n",
        "      dataset.map(to_ids)\n",
        "      # Split into individual chars\n",
        "      .unbatch()\n",
        "      # Form example sequences of SEQ_LENGTH +1\n",
        "      .batch(SEQ_LENGTH + 1, drop_remainder=True)\n",
        "      # Shuffle and form minibatches\n",
        "      .shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)\n",
        "      # And finally split into (input, target) tuples,\n",
        "      # each of length SEQ_LENGTH.\n",
        "      .map(split_input_target))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jw98HnKmEhuh"
      },
      "source": [
        "请注意，在形成原始序列和形成上述批次时，为简单起见，我们使用 `drop_remainder=True`。这意味着文本字符数低于 `(SEQ_LENGTH + 1) * BATCH_SIZE` 的任何角色（客户端）的数据集都将为空。解决此问题的典型方法是使用特殊词例填充批次，然后遮盖损失以忽略填充词例。\n",
        "\n",
        "这会使样本变得有些复杂，所以我们在本教程中仅使用[标准教程](https://tensorflow.google.cn/tutorials/sequences/text_generation)所介绍的完整批次。但在联合环境中，此问题将更为严重，因为可能会有许多用户使用较小的数据集。\n",
        "\n",
        "现在，我们可以预处理我们的 `raw_example_dataset`，并检查类型："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "7rTal7bksWwc"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "(TensorSpec(shape=(8, 100), dtype=tf.int64, name=None), TensorSpec(shape=(8, 100), dtype=tf.int64, name=None))\n"
          ]
        }
      ],
      "source": [
        "example_dataset = preprocess(raw_example_dataset)\n",
        "print(example_dataset.element_spec)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ePT8Oawm8SRP"
      },
      "source": [
        "## 编译模型并基于预处理的数据进行测试"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vEgDsz-48cAq"
      },
      "source": [
        "我们加载了未编译的 Keras 模型，但为了运行 `keras_model.evaluate`，我们需要使用损失和指标对其进行编译。我们还将在一个优化器中编译，该优化器将用作联合学习的设备端优化器。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RsuVZ5KMWnn8"
      },
      "source": [
        "原始教程并不具备字符级准确率（在预测中，最高概率能够落于正确的下一字符的部分）。这是一项有用的指标，因此我们添加了该指标。但是，我们需要为此定义一个新的指标类，因为我们的预测的秩为 3（每个 `BATCH_SIZE * SEQ_LENGTH` 预测的 logits 的向量），而 `SparseCategoricalAccuracy` 仅期望秩为 2 的预测。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "gOUiDBvmWlM9"
      },
      "outputs": [],
      "source": [
        "class FlattenedCategoricalAccuracy(tf.keras.metrics.SparseCategoricalAccuracy):\n",
        "\n",
        "  def __init__(self, name='accuracy', dtype=tf.float32):\n",
        "    super().__init__(name, dtype=dtype)\n",
        "\n",
        "  def update_state(self, y_true, y_pred, sample_weight=None):\n",
        "    y_true = tf.reshape(y_true, [-1, 1])\n",
        "    y_pred = tf.reshape(y_pred, [-1, len(vocab), 1])\n",
        "    return super().update_state(y_true, y_pred, sample_weight)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U2X9eFgt94PM"
      },
      "source": [
        "现在，我们可以编译模型，并基于我们的 `example_dataset` 对其进行评估。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "c3Xd-52-9zGa"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading data from https://storage.googleapis.com/tff-models-public/dickens_rnn.batch8.kerasmodel\n",
            "16195584/16193984 [==============================] - 0s 0us/step\n",
            "16203776/16193984 [==============================] - 0s 0us/step\n",
            "Evaluating on an example Shakespeare character: 0.402000\n",
            "Expected accuracy for random guessing: 0.012\n",
            "Evaluating on completely random data: 0.011\n"
          ]
        }
      ],
      "source": [
        "BATCH_SIZE = 8  # The training and eval batch size for the rest of this tutorial.\n",
        "keras_model = load_model(batch_size=BATCH_SIZE)\n",
        "keras_model.compile(\n",
        "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "    metrics=[FlattenedCategoricalAccuracy()])\n",
        "\n",
        "# Confirm that loss is much lower on Shakespeare than on random data\n",
        "loss, accuracy = keras_model.evaluate(example_dataset.take(5), verbose=0)\n",
        "print(\n",
        "    'Evaluating on an example Shakespeare character: {a:3f}'.format(a=accuracy))\n",
        "\n",
        "# As a sanity check, we can construct some completely random data, where we expect\n",
        "# the accuracy to be essentially random:\n",
        "random_guessed_accuracy = 1.0 / len(vocab)\n",
        "print('Expected accuracy for random guessing: {a:.3f}'.format(\n",
        "    a=random_guessed_accuracy))\n",
        "random_indexes = np.random.randint(\n",
        "    low=0, high=len(vocab), size=1 * BATCH_SIZE * (SEQ_LENGTH + 1))\n",
        "data = collections.OrderedDict(\n",
        "    snippets=tf.constant(\n",
        "        ''.join(np.array(vocab)[random_indexes]), shape=[1, 1]))\n",
        "random_dataset = preprocess(tf.data.Dataset.from_tensor_slices(data))\n",
        "loss, accuracy = keras_model.evaluate(random_dataset, steps=10, verbose=0)\n",
        "print('Evaluating on completely random data: {a:.3f}'.format(a=accuracy))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lH0WzL5L8Lm4"
      },
      "source": [
        "## 通过联合学习微调模型"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NCao4M3L_tsA"
      },
      "source": [
        "TFF 会序列化所有 TensorFlow 计算，因此它们有可能会在非 Python 环境中运行（虽然目前只有由 Python 实现的模拟运行时可用）。即使我们以 Eager 模式 (TF 2.0) 运行，当前 TFF 也能够通过在“`with tf.Graph.as_default()`”语句的上下文中构造必要的运算来序列化 TensorFlow 计算。因此，我们需要提供一个函数，供 TFF 将模型引入其控制的计算图中。我们的做法如下："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "5KadIvFp7m6y"
      },
      "outputs": [],
      "source": [
        "# Clone the keras_model inside `create_tff_model()`, which TFF will\n",
        "# call to produce a new copy of the model inside the graph that it will \n",
        "# serialize. Note: we want to construct all the necessary objects we'll need \n",
        "# _inside_ this method.\n",
        "def create_tff_model():\n",
        "  # TFF uses an `input_spec` so it knows the types and shapes\n",
        "  # that your model expects.\n",
        "  input_spec = example_dataset.element_spec\n",
        "  keras_model_clone = tf.keras.models.clone_model(keras_model)\n",
        "  return tff.learning.from_keras_model(\n",
        "      keras_model_clone,\n",
        "      input_spec=input_spec,\n",
        "      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "      metrics=[FlattenedCategoricalAccuracy()])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZJF_yhJxAi2l"
      },
      "source": [
        "现在，我们准备构造一个联合平均迭代过程，用于改进模型（有关联合平均算法的详细信息，请参阅论文《[Communication-Efficient Learning of Deep Networks from Decentralized Data](https://arxiv.org/abs/1602.05629)》）。\n",
        "\n",
        "在每轮联合训练之后，我们使用编译的 Keras 模型执行标准（非联合）评估。在进行模拟联合学习且存在标准测试数据集时，这种操作对于研究目的而言非常实用。\n",
        "\n",
        "在实际的生产环境中，可以使用相同的技术通过联合学习来训练模型，并基于集中式基准数据集对模型进行评估，供测试或质量保证之用。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "my3PW3qhAMDA"
      },
      "outputs": [],
      "source": [
        "# This command builds all the TensorFlow graphs and serializes them: \n",
        "fed_avg = tff.learning.build_federated_averaging_process(\n",
        "    model_fn=create_tff_model,\n",
        "    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(lr=0.5))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qVOkzs9C9kmv"
      },
      "source": [
        "以下为最简单的循环，在此循环中，我们在单个批次的单个客户端上运行一轮联合平均："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "lrjUrkjq9jYk"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "loss=4.403, accuracy=0.132\n"
          ]
        }
      ],
      "source": [
        "state = fed_avg.initialize()\n",
        "state, metrics = fed_avg.next(state, [example_dataset.take(5)])\n",
        "print('loss={l:.3f}, accuracy={a:.3f}'.format(\n",
        "    l=metrics.train.loss, a=metrics.train.accuracy))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o2CjvVg0FZpS"
      },
      "source": [
        "现在，让我们编写一个更为有趣的训练和评估循环。\n",
        "\n",
        "为了使此模拟仍能相对较快地运行，我们每轮训练三个相同的客户端，每个客户端仅考虑两个迷你批次。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "id": "wE386-rbMCve"
      },
      "outputs": [],
      "source": [
        "def data(client, source=train_data):\n",
        "  return preprocess(source.create_tf_dataset_for_client(client)).take(5)\n",
        "\n",
        "\n",
        "clients = [\n",
        "    'ALL_S_WELL_THAT_ENDS_WELL_CELIA', 'MUCH_ADO_ABOUT_NOTHING_OTHELLO',\n",
        "]\n",
        "\n",
        "train_datasets = [data(client) for client in clients]\n",
        "\n",
        "# We concatenate the test datasets for evaluation with Keras by creating a \n",
        "# Dataset of Datasets, and then identity flat mapping across all the examples.\n",
        "test_dataset = tf.data.Dataset.from_tensor_slices(\n",
        "    [data(client, test_data) for client in clients]).flat_map(lambda x: x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cU3FuY00MOoX"
      },
      "source": [
        "`fed_avg.initialize()` 生成的模型的初始状态基于 Keras 模型的随机初始值设定项，而非加载的权重，因为 `clone_model()` 不会克隆权重。要从预训练模型开始训练，我们直接使用加载的模型在服务器状态下设置模型权重。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "id": "vm_-PU8OFXpY"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Round 0\n",
            "\tEval: loss=3.324, accuracy=0.401\n",
            "\tTrain: loss=4.360, accuracy=0.155\n",
            "Round 1\n",
            "\tEval: loss=4.361, accuracy=0.049\n",
            "\tTrain: loss=4.235, accuracy=0.164\n",
            "Round 2\n",
            "\tEval: loss=4.219, accuracy=0.177\n",
            "\tTrain: loss=4.081, accuracy=0.221\n",
            "Round 3\n",
            "\tEval: loss=4.080, accuracy=0.174\n",
            "\tTrain: loss=3.940, accuracy=0.226\n",
            "Round 4\n",
            "\tEval: loss=3.991, accuracy=0.176\n",
            "\tTrain: loss=3.840, accuracy=0.226\n",
            "Final evaluation\n",
            "\tEval: loss=3.909, accuracy=0.171\n"
          ]
        }
      ],
      "source": [
        "NUM_ROUNDS = 5\n",
        "\n",
        "# The state of the FL server, containing the model and optimization state.\n",
        "state = fed_avg.initialize()\n",
        "\n",
        "state = tff.learning.state_with_new_model_weights(\n",
        "    state,\n",
        "    trainable_weights=[v.numpy() for v in keras_model.trainable_weights],\n",
        "    non_trainable_weights=[\n",
        "        v.numpy() for v in keras_model.non_trainable_weights\n",
        "    ])\n",
        "\n",
        "\n",
        "def keras_evaluate(state, round_num):\n",
        "  # Take our global model weights and push them back into a Keras model to\n",
        "  # use its standard `.evaluate()` method.\n",
        "  keras_model = load_model(batch_size=BATCH_SIZE)\n",
        "  keras_model.compile(\n",
        "      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "    metrics=[FlattenedCategoricalAccuracy()])\n",
        "  tff.learning.assign_weights_to_keras_model(keras_model, state.model)\n",
        "  loss, accuracy = keras_model.evaluate(example_dataset, steps=2, verbose=0)\n",
        "  print('\\tEval: loss={l:.3f}, accuracy={a:.3f}'.format(l=loss, a=accuracy))\n",
        "\n",
        "for round_num in range(NUM_ROUNDS):\n",
        "  print('Round {r}'.format(r=round_num))\n",
        "  keras_evaluate(state, round_num)\n",
        "  state, metrics = fed_avg.next(state, train_datasets)\n",
        "  print('\\tTrain: loss={l:.3f}, accuracy={a:.3f}'.format(\n",
        "      l=metrics.train.loss, a=metrics.train.accuracy))\n",
        "\n",
        "keras_evaluate(state, NUM_ROUNDS + 1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SoshvcHhXVa6"
      },
      "source": [
        "我们仅做默认更改，没有进行足够的训练来实现大幅调整，但是如果您使用更大量的莎士比亚数据进行更长时间的训练，那么您应该会看到更新后的模型所生成的文本风格会有所不同："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {
        "id": "NTUig7QmXavy"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "What of TensorFlow Federated, you ask? Shalways, I will call your\n",
            "compet with any city brought their faces uncompany,\" besumed him. \"When he\n",
            "sticked Madame Defarge pushed the lamps.\n",
            "\n",
            "\"Have I often but no unison. She had probably come,\n"
          ]
        }
      ],
      "source": [
        "# Set our newly trained weights back in the originally created model.\n",
        "keras_model_batch1.set_weights([v.numpy() for v in keras_model.weights])\n",
        "# Text generation requires batch_size=1\n",
        "print(generate_text(keras_model_batch1, 'What of TensorFlow Federated, you ask? '))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4DA1Fkf5mN0s"
      },
      "source": [
        "## 建议的扩展学习\n",
        "\n",
        "本教程只是第一步！以下是有关如何扩展此笔记本的一些想法：\n",
        "\n",
        "- 编写一个更为真实的训练循环，对客户端进行抽样以实现随机训练。\n",
        "- 基于客户端数据集使用“`.repeat(NUM_EPOCHS)`”尝试多个周期的本地训练（例如，[McMahan 等人](https://arxiv.org/abs/1602.05629)所述方法）。另请参阅[图像分类联合学习](federated_learning_for_image_classification.ipynb)，其中提供了相关内容。\n",
        "- 更改 `compile()` 命令以在客户端上尝试使用不同的优化算法。\n",
        "- 尝试针对 `build_federated_averaging_process` 使用 `server_optimizer` 参数以尝试在服务器上应用模型更新的不同算法。\n",
        "- 尝试针对 `build_federated_averaging_process` 使用 `client_weight_fn` 参数数以尝试不同的客户端权重。默认权重客户端会根据客户端上的样本量进行更新，但是您可以执行以下操作：`client_weight_fn=lambda _: tf.constant(1.0)`。"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "federated_learning_for_text_generation.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
